class Solution:
    """Counts square-sum triples with a two-pointer scan over perfect squares."""

    def countTriples(self, n: int) -> int:
        """Return how many ordered triples (a, b, c) satisfy a*a + b*b == c*c.

        All three values are drawn from 1..n inclusive. For each candidate
        ``c`` the squares of every smaller integer are already stored in
        ascending order, so a pair summing to ``c*c`` is found by walking
        one pointer up from the smallest square and one down from the
        largest.

        Args:
            n: Upper bound (inclusive) for a, b and c.

        Returns:
            The number of ordered triples; each unordered pair {a, b} is
            counted twice, once as (a, b, c) and once as (b, a, c).
        """
        squares = []  # squares of 1..c-1, ascending, while processing c
        total = 0
        for c in range(1, n + 1):
            target = c * c
            lo = 0
            hi = len(squares) - 1
            # The pointers may coincide; a match there would need
            # 2*a*a == c*c, which has no positive-integer solution.
            while lo <= hi:
                pair_sum = squares[lo] + squares[hi]
                if pair_sum == target:
                    # One hit stands for both (a, b, c) and (b, a, c).
                    total += 2
                    lo += 1
                    hi -= 1
                    continue
                if pair_sum < target:
                    lo += 1  # sum too small: move to a larger low square
                else:
                    hi -= 1  # sum too large: move to a smaller high square
            squares.append(target)
        return total


if __name__ == "__main__":
    # Demo run: prints the triple count for each limit.
    # Expected output: 2 (for n=5), then 4 (for n=10).
    for limit in (5, 10):
        print(Solution().countTriples(limit))
